package com.jframe.sys.utils;

import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
/**
 * Request wrapper that replaces characters commonly used in XSS payloads
 * ({@code > < ' " & \ #}) with look-alike non-ASCII characters in parameter
 * values, header values and String-typed request attributes.
 *
 * <p>The untouched request is available through {@link #getOrgRequest()} or
 * {@link #getOrgRequest(HttpServletRequest)} when raw values are needed.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
	/** The request passed to the constructor, kept so raw values stay reachable. */
	HttpServletRequest orgRequest = null;

	/**
	 * Wraps the given request.
	 *
	 * @param request the request to wrap
	 */
	public XssHttpServletRequestWrapper(HttpServletRequest request) {
		super(request);
		orgRequest = request;
	}

	/**
	 * Returns the named attribute; String attributes are XSS-encoded, every
	 * other type is returned unchanged.
	 *
	 * @param name attribute name
	 * @return the (possibly encoded) attribute value, or {@code null} if absent
	 */
	@Override
	public Object getAttribute(String name) {
		Object value = super.getAttribute(name);
		if (value instanceof String) {
			return xssEncode((String) value);
		}
		return value;
	}

	/**
	 * Returns the XSS-encoded value of the named parameter.
	 *
	 * <p>The name itself is encoded before the lookup, so a name containing one
	 * of the replaced characters will not match a parameter stored under the
	 * raw name. Use {@link #getOrgRequest()} to read raw values.
	 *
	 * @param name parameter name
	 * @return the encoded value, or {@code null} if the parameter is absent
	 */
	@Override
	public String getParameter(String name) {
		String value = super.getParameter(xssEncode(name));
		if (value != null) {
			value = xssEncode(value);
		}
		return value;
	}

	/**
	 * Returns the XSS-encoded value of the named header.
	 *
	 * <p>The name itself is encoded before the lookup. Use
	 * {@link #getOrgRequest()} to read raw header values.
	 *
	 * @param name header name
	 * @return the encoded value, or {@code null} if the header is absent
	 */
	@Override
	public String getHeader(String name) {
		String value = super.getHeader(xssEncode(name));
		if (value != null) {
			value = xssEncode(value);
		}
		return value;
	}

	/**
	 * Returns XSS-encoded copies of all values of the named parameter.
	 *
	 * @param name parameter name
	 * @return a new array of encoded values, or {@code null} if the parameter is absent
	 */
	@Override
	public String[] getParameterValues(String name) {
		return encodeValues(super.getParameterValues(name));
	}

	/**
	 * Returns a copy of the parameter map in which every value is XSS-encoded.
	 * Keys are left as they are.
	 *
	 * @return a new, mutable map holding encoded copies of the value arrays
	 */
	@Override
	public Map<String, String[]> getParameterMap() {
		// The Servlet API only promises a Map here, so no cast to a concrete class.
		Map<String, String[]> rawMap = super.getParameterMap();
		Map<String, String[]> encodedMap = new LinkedHashMap<String, String[]>();
		if (rawMap == null) {
			return encodedMap;
		}
		for (Map.Entry<String, String[]> entry : rawMap.entrySet()) {
			encodedMap.put(entry.getKey(), encodeValues(entry.getValue()));
		}
		return encodedMap;
	}

	/**
	 * Encodes every element of the given array into a new array, so the
	 * array owned by the wrapped request is never modified.
	 *
	 * @param values raw values, may be {@code null}
	 * @return a new array of encoded values, or {@code null} if {@code values} is {@code null}
	 */
	private static String[] encodeValues(String[] values) {
		if (values == null) {
			return null;
		}
		String[] encoded = new String[values.length];
		for (int i = 0; i < values.length; i++) {
			encoded[i] = xssEncode(values[i]);
		}
		return encoded;
	}

	/**
	 * Replaces the half-width characters that commonly enable XSS with
	 * look-alike non-ASCII characters.
	 *
	 * @param s text to encode, may be {@code null}
	 * @return the encoded text; {@code null} or empty input is returned as is
	 */
	private static String xssEncode(String s) {
		if (s == null || s.isEmpty()) {
			return s;
		}
		StringBuilder sb = new StringBuilder(s.length() + 16);
		for (int i = 0; i < s.length(); i++) {
			char c = s.charAt(i);
			switch (c) {
			case '>':
				sb.append('＞');// full-width greater-than sign
				break;
			case '<':
				sb.append('＜');// full-width less-than sign
				break;
			case '\'':
				sb.append('‘');// left single quotation mark
				break;
			case '\"':
				sb.append('“');// left double quotation mark
				break;
			case '&':
				sb.append('＆');// full-width ampersand
				break;
			case '\\':
				sb.append('＼');// full-width backslash
				break;
			case '#':
				sb.append('＃');// full-width number sign
				break;
			default:
				sb.append(c);
				break;
			}
		}
		return sb.toString();
	}

	/**
	 * Returns the request that was passed to the constructor.
	 *
	 * @return the wrapped request
	 */
	public HttpServletRequest getOrgRequest() {
		return orgRequest;
	}

	/**
	 * Static variant of {@link #getOrgRequest()}.
	 *
	 * @param req any request
	 * @return the wrapped request if {@code req} is an
	 *         {@code XssHttpServletRequestWrapper}, otherwise {@code req} itself
	 */
	public static HttpServletRequest getOrgRequest(HttpServletRequest req) {
		if (req instanceof XssHttpServletRequestWrapper) {
			return ((XssHttpServletRequestWrapper) req).getOrgRequest();
		}
		return req;
	}

}
